import numpy as np
import copy


def from_one_hot(one_hot):
    """Return the index of the single 1 in a one-hot vector.

    Args:
        one_hot: 1-D sequence (list, tuple) or numpy array that contains
            exactly one 1 and zeros everywhere else.

    Returns:
        The position of the 1, as returned by ``np.argmax``.

    Raises:
        ValueError: if ``one_hot`` is not a valid one-hot vector.
    """
    # Convert first: for a plain list, ``one_hot == 1`` is a single bool
    # rather than an element-wise mask, so list inputs were always rejected.
    arr = np.asarray(one_hot)
    if np.sum(arr == 1) != 1 or np.sum(arr == 0) != len(arr) - 1:
        raise ValueError("input is not one-hot", one_hot)
    # Index of the element whose value is 1.
    return np.argmax(arr)


class GuessingNumber:
    """Turn-based, shared-reward multi-agent number-guessing game.

    Each player is dealt a secret digit 0-9 together with the matching
    7-element board from ``Radioland_Slim``. An agent's observation shows the
    other players' digits and boards but hides its own (see
    ``get_observations``). On its turn the current player either guesses its
    own digit (once per game) or reveals one board element of another player.
    The revealed element is written into that player's ``hit`` vector, which
    every agent observes. The episode ends once every player has guessed.
    Every agent receives the same reward on each step.

    Action encoding (mask length ``action_space_noop``):
        * ``0 .. 9``: guess that digit (reward 10 if correct, -1 otherwise).
        * ``10 .. action_space - 1``: reveal. With ``k = a - 10``, reveal
          element ``k % 7`` of the player seated ``k // 7 + 1`` places after
          the current player (reward -0.1).
        * ``action_space`` (last index): no-op. It is the only legal action
          for agents whose turn it is not.
    """

    def __init__(self, num_agents=3):
        """Create the game and deal a random digit to each of ``num_agents`` players."""
        # Digit -> 7-element board. The patterns are consistent with a
        # seven-segment display with segments ordered top, upper-left,
        # upper-right, middle, lower-left, lower-right, bottom (inferred from
        # the data).
        self.Radioland_Slim = {
            0: [1, 1, 1, 0, 1, 1, 1],
            1: [0, 0, 1, 0, 0, 1, 0],
            2: [1, 0, 1, 1, 1, 0, 1],
            3: [1, 0, 1, 1, 0, 1, 1],
            4: [0, 1, 1, 1, 0, 1, 0],
            5: [1, 1, 0, 1, 0, 1, 1],
            6: [1, 1, 0, 1, 1, 1, 1],
            7: [1, 0, 1, 0, 0, 1, 0],
            8: [1, 1, 1, 1, 1, 1, 1],
            9: [1, 1, 1, 1, 0, 1, 1],
        }

        self.num_agents = num_agents

        # Size of the one-hot ``hit_time`` counter. available_actions() stops
        # reveals once hit_time reaches hit_limit - 1, so that is the
        # per-player reveal budget.
        self.hit_limit = 7 + 2 * (self.num_agents - 1)

        # 10 guess actions + 7 reveal actions per other player (+1 no-op).
        self.action_space = 10 + (self.num_agents - 1) * 7
        self.action_space_noop = self.action_space + 1

        self.guess_space_noop = 10 + 1 + 1
        self.reveal_space_noop = (self.num_agents - 1) * 7 + 1 + 1

        # Length of each list returned by get_observations() (guessing_number2 layout).
        self.observation_space = 1 + self.num_agents * (1 + 1 + 3 * 7 + self.hit_limit + 1) + self.num_agents * (11 + 8)

        self.current_player = 0

        self.players = []
        for _ in range(num_agents):
            player_num = np.random.randint(0, 10)
            self.players.append({
                "number": player_num,
                "board": self.Radioland_Slim[player_num],
                "hit": [-1] * 7,  # -1 = element not revealed yet
                "guessed": 0,
                # Previously missing, so get_observations/available_actions/
                # step raised KeyError until reset() had been called.
                "hit_time": 0,
            })

        self.done = False

        # Per-step history: [players snapshot, player, action, done, reward].
        # NOTE: reset() does not clear it, so it accumulates across episodes.
        self.log = []

    def reset(self):
        """Deal new digits, clear all per-player state and start a new episode.

        Returns:
            ``(obs, available_actions)`` for the fresh state.
        """
        for player in self.players:
            player["number"] = np.random.randint(0, 10)
            player["board"] = self.Radioland_Slim[player["number"]]
            player["hit"] = [-1] * 7
            player["guessed"] = 0
            player["hit_time"] = 0

        self.done = False
        self.current_player = 0

        obs = self.get_observations()
        available_actions = self.available_actions()

        return obs, available_actions

    def get_observations(self):
        """Build one flat observation list per agent.

        Within observation ``i``, players are listed relative to agent ``i``.
        Offset ``j`` means player ``(i + j) % num_agents``, so offset 0 is
        agent ``i`` itself. The layout is:

            * one-hot agent id ``i`` (``num_agents``)
            * one-hot current player (``num_agents``)
            * done flag (1)
            * per player: each ``hit`` element encoded as
              0 -> [1, 0, 0], 1 -> [0, 1, 0], -1 (unrevealed) -> [0, 0, 1]
              (``num_agents * 21``)
            * per player: one-hot ``hit_time`` (``num_agents * hit_limit``)
            * per player: ``guessed`` flag (``num_agents``)
            * per player: one-hot digit over 11 slots; the agent's own digit is
              hidden and encoded as the last slot (``num_agents * 11``)
            * per player: 7-element board plus a trailing "hidden" flag; the
              agent's own board is zeros with the flag set (``num_agents * 8``)

        The total length equals ``self.observation_space``.

        Raises:
            ValueError: if a ``hit`` entry is not one of 0, 1, -1.
        """
        n = self.num_agents
        obs = []
        for i in range(n):
            # Players in seat order, starting with agent i itself.
            seats = [self.players[(i + j) % n] for j in range(n)]
            o = []

            agent_id = [0] * n
            agent_id[i] = 1
            o += agent_id

            current = [0] * n
            current[self.current_player] = 1
            o += current

            o += [int(self.done)]

            for p in seats:
                for h in p["hit"]:
                    if h == 0:
                        o += [1, 0, 0]
                    elif h == 1:
                        o += [0, 1, 0]
                    elif h == -1:
                        o += [0, 0, 1]
                    else:
                        raise ValueError(f"hit {h} not in [0,1,-1]")

            for p in seats:
                hit_time = [0] * self.hit_limit
                hit_time[p["hit_time"]] = 1
                o += hit_time

            for p in seats:
                o += [p["guessed"]]

            for j, p in enumerate(seats):
                num = [0] * (10 + 1)
                if j == 0:
                    num[-1] = 1  # own digit is hidden
                else:
                    num[p["number"]] = 1
                o += num

            for j, p in enumerate(seats):
                if j == 0:
                    o += [0] * 7  # own board is hidden
                    o += [1]
                else:
                    o += p["board"]
                    o += [0]

            obs.append(o)
        return obs

    def available_actions(self):
        """Return one legal-action mask (np.ndarray of 0./1.) per agent.

        If the game is done, or it is not the agent's turn, only the final
        no-op slot is legal. For the current player, the no-op is illegal.
        Guesses (slots 0-9) are disabled once it has guessed. Reveals (slots
        10 and up) are disabled once ``hit_time`` reaches ``hit_limit - 1``.

        Raises:
            AssertionError: if ``hit_time`` exceeds ``hit_limit - 1`` or the
                current player has no legal action left.
        """
        masks = []
        for i, player in enumerate(self.players):
            if self.done or i != self.current_player:
                mask = np.zeros(self.action_space_noop)
                mask[-1] = 1
            else:
                mask = np.ones(self.action_space_noop)
                if player["guessed"] == 1:
                    mask[:10] = 0
                if player["hit_time"] >= self.hit_limit - 1:
                    # Message previously used `.self.hit_limit`, an attribute
                    # access on an int that broke the assertion message itself.
                    assert player["hit_time"] == self.hit_limit - 1, (
                        "hit_time is too high", player["hit_time"], self.hit_limit)
                    mask[10:] = 0
                mask[-1] = 0
                assert sum(mask) != 0, 'no available action'
            masks.append(mask)
        return masks

    def step(self, actions):
        """Advance the game by one turn.

        Args:
            actions: list or np.ndarray with one entry per agent. Each entry is
                either an integer action index (a Python ``int`` or any numpy
                integer scalar) or a one-hot list/array of length
                ``action_space_noop``. Every entry must be legal according to
                ``available_actions()``. Only the current player's entry is
                applied.

        Returns:
            ``(obs, rewards, done, available_actions)``. ``rewards`` repeats the
            single shared reward for every agent. If the episode is already
            over, the state is unchanged and all rewards are 0.

        Raises:
            TypeError: if ``actions`` or one of its entries has an unsupported type.
            ValueError: if a list/array entry is not one-hot.
            AssertionError: if ``len(actions)`` is wrong or an action is illegal.
        """
        if self.done:
            obs = self.get_observations()
            available_actions = self.available_actions()
            return obs, [0] * self.num_agents, self.done, available_actions

        pre_available_actions = self.available_actions()
        if not isinstance(actions, (list, np.ndarray)):
            raise TypeError(actions, " is not a list or np.ndarray!")
        assert len(actions) == self.num_agents
        for i in range(self.num_agents):
            agent_action = actions[i]
            # np.integer covers int32/int64/uint8/...; previously only np.int64 was accepted.
            if isinstance(agent_action, (int, np.integer)):
                index = agent_action
            elif isinstance(agent_action, (list, np.ndarray)):
                index = from_one_hot(agent_action)
            else:
                raise TypeError("Action should be one-hot list or int!", agent_action, type(agent_action))
            assert pre_available_actions[i][index], (pre_available_actions[i], index, "Invalid action!")
            if i == self.current_player:
                action = index

        player = self.players[self.current_player]
        # Snapshot of the pre-move state for the episode log.
        log_entry = [copy.deepcopy(self.players), self.current_player]

        if action < 10:  # guess own digit
            assert player["guessed"] == 0
            reward = 10 if action == player["number"] else -1
            player["guessed"] = 1
            self.done = all(p["guessed"] == 1 for p in self.players)
        else:  # reveal one board element of another player
            offset, reveal_index = divmod(action - 10, 7)
            target_player = (offset + 1 + self.current_player) % self.num_agents
            reward = -0.1
            player["hit_time"] += 1
            assert player["hit_time"] < self.hit_limit, ("hit_time excess limit", player["hit_time"], self.hit_limit)

            target = self.players[target_player]
            target["hit"][reveal_index] = target["board"][reveal_index]

        self.current_player = (self.current_player + 1) % self.num_agents
        obs = self.get_observations()
        available_actions = self.available_actions()

        # Defensive only: available_actions() always leaves a no-op for waiting
        # agents and asserts the current player has a legal move, so this
        # condition should never be true.
        if any(sum(mask) == 0 for mask in pre_available_actions):
            self.done = True

        log_entry += [action, self.done, reward]
        self.log.append(log_entry)

        return obs, [reward] * self.num_agents, self.done, available_actions

    def render(self):
        """Print every player's state dictionary, one line per player."""
        for i, player in enumerate(self.players):
            print(f"Player {i}: {player}")
